import torch
import os
import numpy as np
from DQN import ICN


class ICP:
    """Training wrapper around an ``ICN`` network and its target copy.

    The object owns the online network (``self.ICN``), a target network
    (``self.ICN_target``) that is periodically overwritten with the online
    weights, and an Adam optimizer over the online parameters.

    Checkpoint layout: on construction, weights are restored from
    ``<save_dir>/ICN.pkl`` when that file exists, while ``save_model`` writes
    periodic snapshots to ``<save_dir>/model/<train_step>_ICN.pkl``.  A
    snapshot therefore has to be copied/renamed to ``<save_dir>/ICN.pkl`` to
    be picked up on the next start.
    """

    def __init__(self, args):
        """Build the networks and optimizer, then restore a checkpoint if any.

        Args:
            args: configuration object.  Attributes read by this class:
                ``cuda``, ``lr``, ``save_dir``, ``vec_env``, ``n_agents``,
                ``gamma``, ``grad_clip``, ``target_update_rate``,
                ``save_rate``, ``rnn_layers`` and ``hidden_size``.  It is
                also forwarded unchanged to the ``ICN`` constructor.
        """
        self.args = args
        self.train_step = 0

        # Online network and the target network used for bootstrapping.
        self.ICN = ICN(args)
        self.ICN_target = ICN(args)
        if self.args.cuda:
            self.ICN = self.ICN.cuda()
            self.ICN_target = self.ICN_target.cuda()

        # Only the online network is optimized.
        self.optim = torch.optim.Adam(self.ICN.parameters(), lr=self.args.lr)

        # Directory that holds the checkpoint to resume from.
        self.model_path = self.args.save_dir
        os.makedirs(self.model_path, exist_ok=True)

        # Restore saved weights when a checkpoint is present.  Loading onto
        # the CPU first lets a checkpoint written on a GPU machine be
        # restored on a CPU-only one; load_state_dict then copies the values
        # onto whatever device the parameters live on.
        checkpoint = os.path.join(self.model_path, 'ICN.pkl')
        if os.path.exists(checkpoint):
            state_dict = torch.load(checkpoint, map_location='cpu')
            self.ICN.load_state_dict(state_dict)
            print('successfully loaded ICN: {}'.format(checkpoint))

        # Sync the target network only after the optional restore, so that it
        # starts from the same weights as the online network in both cases.
        self._update_target_network()

    def _update_target_network(self):
        """Overwrite the target network's weights with the online network's."""
        self.ICN_target.load_state_dict(self.ICN.state_dict())

    def loss(self, transitions):
        """Compute the mean squared TD error over one rollout.

        Args:
            transitions: sequence with one dict per time step, each holding
                the keys ``'q_value'``, ``'q_target'``, ``'r'`` and
                ``'done'``.  ``'q_value'`` and ``'q_target'`` must be tensors
                that reshape to ``(vec_env, n_agents)``; ``'r'`` must stack
                to ``(episode_len, vec_env)``.  ``'done'`` is presumably the
                same shape as ``'r'`` -- TODO confirm against the caller.

        Returns:
            Scalar tensor with the loss.

        Raises:
            AssertionError: if the stacked rewards are not shaped
                ``(episode_len, vec_env)``.
        """
        episode_len = len(transitions)

        # Gather each field of the per-step records into its own list.
        q_value = [step['q_value'] for step in transitions]
        q_target = [step['q_target'] for step in transitions]
        r = torch.tensor(np.array([step['r'] for step in transitions]),
                         dtype=torch.float32)
        done = torch.tensor(np.array([step['done'] for step in transitions]),
                            dtype=torch.float32)
        if self.args.cuda:
            r = r.cuda()
            done = done.cuda()

        assert r.shape == (episode_len, self.args.vec_env), r.shape

        per_agent_shape = (episode_len, self.args.vec_env, self.args.n_agents)

        # The per-agent values are summed over the agent axis.
        q_value = torch.stack(q_value, dim=0).reshape(per_agent_shape)
        sum_q = q_value.sum(dim=-1)

        # NOTE(review): q_target is used as provided (no detach here); it is
        # presumably produced without gradient tracking -- confirm in caller.
        q_target = torch.stack(q_target, dim=0).reshape(per_agent_shape)

        # Shift the target values one step forward in time; the step after
        # the last one has no successor and bootstraps from zero.
        q_next = torch.cat([q_target[1:], torch.zeros_like(q_target[:1])],
                           dim=0)
        sum_q_next = q_next.sum(dim=-1)

        # One-step TD target; no bootstrapping where ``done`` is set.
        target_q = r + self.args.gamma * sum_q_next * (1 - done)

        # Mask for step t is done[t - 1] (and done[0] for t == 0): the error
        # of a step whose predecessor was flagged done is zeroed out.
        pad = torch.cat([done[:1], done[:-1]], dim=0)

        # The mean is taken over every entry, masked ones included.
        td_error = (target_q - sum_q) * (1 - pad)
        return td_error.pow(2).mean()

    def trained_step(self, loss):
        """Run one optimizer update and the periodic bookkeeping.

        Back-propagates ``loss``, clips the gradient norm of the online
        network to ``args.grad_clip`` and steps the optimizer.  Then, for a
        positive step counter, syncs the target network every
        ``args.target_update_rate`` steps and saves a snapshot every
        ``args.save_rate`` steps.  Finally increments ``self.train_step``.

        Args:
            loss: scalar tensor to back-propagate (e.g. from ``loss``).
        """
        self.optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.ICN.parameters(),
                                       self.args.grad_clip)
        self.optim.step()

        # Step 0 is skipped for both periodic actions.
        if self.train_step > 0:
            if self.train_step % self.args.target_update_rate == 0:
                self._update_target_network()
            if self.train_step % self.args.save_rate == 0:
                self.save_model(self.train_step)
        self.train_step += 1

    def save_model(self, train_step):
        """Write the online weights to ``<save_dir>/model/<train_step>_ICN.pkl``.

        Args:
            train_step: value used as the file name prefix.
        """
        model_dir = os.path.join(self.args.save_dir, "model")
        os.makedirs(model_dir, exist_ok=True)
        file_name = str(train_step) + '_ICN.pkl'
        torch.save(self.ICN.state_dict(), os.path.join(model_dir, file_name))

    def init_critic_hidden(self, batch_size):
        """Return an all-zero hidden state.

        Args:
            batch_size: size of the second dimension of the state.

        Returns:
            Zero tensor shaped ``(rnn_layers, batch_size, hidden_size)``,
            moved to the GPU when ``args.cuda`` is set.
        """
        hidden = torch.zeros(self.args.rnn_layers, batch_size,
                             self.args.hidden_size)
        return hidden.cuda() if self.args.cuda else hidden
